//============================================================================
//
// This file is part of the Thea toolkit.
//
// This software is distributed under the BSD license, as detailed in the
// accompanying LICENSE.txt file. Portions are derived from other works:
// their respective licenses and copyright information are reproduced in
// LICENSE.txt and/or in the relevant source files.
//
// Author: Siddhartha Chaudhuri
// First version: 2011
//
//============================================================================

#ifndef __Thea_Algorithms_Zernike2_hpp__
#define __Thea_Algorithms_Zernike2_hpp__

#include "../Common.hpp"
#include "../IAddressableMatrix.hpp"
#include "../MatVec.hpp"
#include <complex>

namespace Thea {
namespace Algorithms {

/**
 * Compute Zernike moments of a 2D distribution, represented as a matrix of density values. The density values may be
 * multidimensional, i.e. the matrix elements may be vectors or colors.
 *
 * This class adapts code for the LightField Descriptor from Ding-Yun Chen et al.,
 * http://3d.csie.ntu.edu.tw/~dynamic/3DRetrieval/index.html .
 */
class THEA_API Zernike2
{
  public:
    /** The matrix type storing N-dimensional moments. Each column is a moment. */
    template <int N, typename ScalarT> using MomentMatrix = Matrix< N, Eigen::Dynamic, std::complex<ScalarT> >;

    /** %Options for generating Zernike moments. */
    struct THEA_API Options
    {
      int angular_steps;  ///< Number of angular steps.
      int radial_steps;   ///< Number of radial steps.
      int lut_radius;     ///< Radius of Zernike basis function for lookup table.

      /** Constructor. */
      Options(int angular_steps_ = 12, int radial_steps_ = 3, int lut_radius_ = 50)
      : angular_steps(angular_steps_), radial_steps(radial_steps_), lut_radius(lut_radius_)
      {}

      /** Get the set of default options. */
      static Options const & defaults() { static Options const def; return def; }

    }; // struct Options

    /** Constructor. */
    Zernike2(Options const & opts_ = Options::defaults());

    /** Get the number of moments generated by a call to compute(). */
    intx numMoments() const { return opts.angular_steps * opts.radial_steps; }

    /**
     * Compute Zernike moments of a 2D distribution, represented as a matrix of single- or multi-dimensional density values
     * (such as reals, vectors or colors). If the density values have more than one dimension, they should allow array
     * addressing (<code>operator[](intx i)</code>). The template parameter N, inferred from the last argument, must be the same
     * as the number of dimensions of the input.
     *
     * @param distrib The distribution represented as an addressable matrix of density values.
     * @param center_x The x-coordinate (column) of the center of the non-zero region of the distribution, in matrix
     *   coordinates.
     * @param center_y The y-coordinate (row) of the center of the non-zero region of the distribution, in matrix coordinates.
     * @param radius The radius of the non-zero region of the distribution, measured from the center, in matrix coordinates. All
     *   zero elements can be ignored when specifying this number.
     * @param moments Used to return the Zernike moments, specified in "angle-major, radius-minor" order. Each moment is a
     *   column of the matrix.
     *
     * @return The number of pixels that have non-zero values and were used to compute the moments.
     */
    template <int N, typename T, typename ScalarT>
    intx compute(IAddressableMatrix<T> const & distrib, double center_x, double center_y, double radius,
                 MomentMatrix<N, ScalarT> & moments) const;

  private:
    /** Generate lookup table for moments. */
    void generateBasisLUT() const;

    /**
     * Check if an input element (type T) is zero. This is the generic implementation for multi-channel inputs: the element
     * is zero only if each of its first N channels, accessed via <code>operator[]</code>, compares equal to zero. The
     * single-channel case is separately specialized.
     */
    template <int N> struct IsZero
    {
      template <typename T> static bool check(T const & t)
      {
        for (int i = 0; i < N; ++i)
          if (t[i] != 0) return false;

        return true;
      }
    };

    /**
     * Add a scaled increment to a moment. This is the generic implementation for multi-channel inputs: for each channel i,
     * conj(x) * t[i] is added to acc[i]. The single-channel case is separately specialized.
     *
     * @a acc is an Eigen column view taken by value, so that it can bind to the temporary returned by
     * <code>MomentMatrix::col()</code>; writes through it modify the underlying moment matrix.
     */
    template <int N, typename ScalarT> struct Accum
    {
      template <typename T>
      static void add(T const & t, std::complex<double> const & x, typename MomentMatrix<N, ScalarT>::ColXpr acc)
      {
        for (intx i = 0; i < acc.size(); ++i)
        {
          acc[i].real(acc[i].real() + static_cast<ScalarT>(x.real() * t[i]));
          acc[i].imag(acc[i].imag() - static_cast<ScalarT>(x.imag() * t[i]));  // subtract => accumulate the conjugate
        }
      }
    };

    /**
     * Lookup table class (4D grid). Elements are stored contiguously in row-major order, i.e. the last index varies
     * fastest.
     */
    class LUT
    {
      public:
        typedef std::complex<double> value_type;  ///< Type of values in the grid.

        /** Default constructor. */
        LUT() : extents{0, 0, 0, 0} {}

        /** Resize the grid to given dimensions. */
        void resize(intx d0, intx d1, intx d2, intx d3)
        {
          values.resize((size_t)(d0 * d1 * d2 * d3));
          extents[0] = d0;
          extents[1] = d1;
          extents[2] = d2;
          extents[3] = d3;
        }

        /** Access a read-only element. */
        value_type const & operator()(intx i0, intx i1, intx i2, intx i3) const
        {
          return values[index(i0, i1, i2, i3)];
        }

        /** Access a read-write element. */
        value_type & operator()(intx i0, intx i1, intx i2, intx i3)
        {
          return values[index(i0, i1, i2, i3)];
        }

      private:
        /**
         * Row-major linear offset of element (i0, i1, i2, i3). Each index is scaled by the product of the extents of the
         * dimensions that FOLLOW it, so distinct in-range index tuples map to distinct offsets in [0, d0 * d1 * d2 * d3).
         */
        size_t index(intx i0, intx i1, intx i2, intx i3) const
        {
          return (size_t)(((i0 * extents[1] + i1) * extents[2] + i2) * extents[3] + i3);
        }

        Array<value_type> values;
        intx extents[4];

    }; // class LUT

    Options opts;                ///< Set of options.
    mutable LUT lut;             ///< Coefficient lookup table.
    mutable bool lut_generated;  ///< Has the LUT been generated?

}; // class Zernike2

// Specializations for single-channel input.

/** Check if a single-channel (scalar) input element is zero. */
template <>
struct Zernike2::IsZero<1>
{
  /** @return True if and only if @a t compares equal to zero. */
  template <typename T> static bool check(T const & t)
  {
    return t == 0;
  }
};

/**
 * Add a scaled increment to a moment, for single-channel (scalar) input: conj(x) * t is added to the single entry of
 * @a acc.
 *
 * @a acc is an Eigen column view taken by value, matching the generic implementation. It must not be a non-const lvalue
 * reference, since callers pass the temporary returned by <code>MomentMatrix::col()</code>. Writes through the view
 * modify the underlying moment matrix.
 */
template <typename ScalarT>
struct Zernike2::Accum<1, ScalarT>
{
  template <typename T>
  static void add(T const & t, std::complex<double> const & x, typename MomentMatrix<1, ScalarT>::ColXpr acc)
  {
    acc[0].real(acc[0].real() + static_cast<ScalarT>(x.real() * t));
    acc[0].imag(acc[0].imag() - static_cast<ScalarT>(x.imag() * t));  // subtract => accumulate the conjugate
  }
};

template <int N, typename T, typename ScalarT>
intx
Zernike2::compute(IAddressableMatrix<T> const & distrib, double center_x, double center_y, double radius,
                  Zernike2::MomentMatrix<N, ScalarT> & moments) const
{
  alwaysAssertM(radius > 0, "Zernike2: Radius must be greater than zero");

  this->generateBasisLUT();

  moments.resize(Eigen::NoChange, numMoments());
  moments.setZero();

  intx const ncols = distrib.cols();
  intx const nrows = distrib.rows();

  // Don't go outside the specified radius. The explicit <intx> template argument keeps std::max/min well-formed even when
  // intx is not the same type as long (e.g. on LLP64 platforms), where a bare 0L literal breaks template deduction.
  intx const min_x = std::max<intx>(0,         static_cast<intx>(std::ceil (center_x - radius)));
  intx const max_x = std::min<intx>(ncols - 1, static_cast<intx>(std::floor(center_x + radius)));

  intx const min_y = std::max<intx>(0,         static_cast<intx>(std::ceil (center_y - radius)));
  intx const max_y = std::min<intx>(nrows - 1, static_cast<intx>(std::floor(center_y + radius)));

  // Scale from matrix coordinates to LUT coordinates: a distance of `radius` maps to `opts.lut_radius` LUT cells
  double const r_radius = opts.lut_radius / radius;

  intx count = 0;
  for (intx y = min_y; y <= max_y; ++y)
  {
    for (intx x = min_x; x <= max_x; ++x)
    {
      T const & density = distrib.at(y, x);
      if (IsZero<N>::check(density))
        continue;

      // Map (x, y) to LUT coordinates centered at (lut_radius, lut_radius), then split into an integer cell and a
      // fractional offset for bilinear interpolation. Since x >= center_x - radius (and likewise for y), tx and ty are
      // non-negative up to rounding, so truncation selects the lower cell.
      double const tx = (x - center_x) * r_radius + opts.lut_radius;
      double const ty = (y - center_y) * r_radius + opts.lut_radius;
      intx const ix = static_cast<intx>(tx);
      intx const iy = static_cast<intx>(ty);
      double const fx = tx - ix;
      double const fy = ty - iy;

      // NOTE(review): the interpolation reads cells ix + 1 and iy + 1, which can reach 2 * lut_radius + 1 for pixels at
      // exactly `radius` from the center along an axis. This assumes generateBasisLUT() sizes the spatial LUT dimensions
      // to at least 2 * lut_radius + 2 -- confirm against its implementation.

      // Summation of basis function
      for (intx p = 0; p < opts.angular_steps; ++p)
      {
        for (intx r = 0; r < opts.radial_steps; ++r)
        {
          std::complex<double> const x1 = lut(p, r, ix, iy    ) + (lut(p, r, ix + 1, iy    ) - lut(p, r, ix, iy    )) * fx;
          std::complex<double> const x2 = lut(p, r, ix, iy + 1) + (lut(p, r, ix + 1, iy + 1) - lut(p, r, ix, iy + 1)) * fx;
          std::complex<double> const x3 = x1 + (x2 - x1) * fy;

          // Moments are stored angle-major, radius-minor
          Accum<N, ScalarT>::add(density, x3, moments.col(p * opts.radial_steps + r));
        }
      }

      ++count;
    }
  }

  // Normalize by the number of contributing (non-zero) pixels
  if (count > 0)
    moments /= static_cast<ScalarT>(count);

  return count;
}

} // namespace Algorithms
} // namespace Thea

#endif
